{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from matplotlib import pyplot as plt\n",
    "from res.plot_lib import set_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the shared plot style; this call has to live in a cell of its own\n",
    "set_default(\n",
    "    figsize=(16, 8),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training set\n",
    "torch.manual_seed(0)  # fix the RNG so re-running the notebook reproduces the same samples\n",
    "m = 20  # nb of training pairs\n",
    "x = (torch.rand(m) - 0.5) * 12  # inputs, uniform in [-6, 6) since rand() is in [0, 1)\n",
    "y = x * torch.sin(x)  # targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the training pairs\n",
    "x_np, y_np = x.numpy(), y.numpy()\n",
    "plt.plot(x_np, y_np, 'o')\n",
    "plt.axis('equal')\n",
    "plt.ylim(-10, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define network architecture (try different non-linearities)\n",
    "\n",
    "# Activation class instantiated after each hidden layer.\n",
    "# nn.Tanh is the alternative to try: assign it here instead of nn.ReLU.\n",
    "non_linear = nn.ReLU\n",
    "\n",
    "# 1 input -> 20 -> 20 -> 1 output, with dropout (p=0.05) in front of the\n",
    "# first two linear layers.\n",
    "net = nn.Sequential(\n",
    "    nn.Dropout(p=0.05),\n",
    "    nn.Linear(1, 20),\n",
    "    non_linear(),\n",
    "    nn.Dropout(p=0.05),\n",
    "    nn.Linear(20, 20),\n",
    "    non_linear(),\n",
    "    nn.Linear(20, 1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function and optimiser used for training\n",
    "criterion = nn.MSELoss()\n",
    "optimiser = optim.SGD(\n",
    "    net.parameters(),\n",
    "    lr=1e-2,\n",
    "    weight_decay=1e-5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "# Make sure dropout is active: an evaluation cell may have left the net in\n",
    "# eval() mode if this cell is re-run later in the session.\n",
    "net.train()\n",
    "\n",
    "# Column-vector views of the training data; they do not change between epochs\n",
    "inputs = x.view(-1, 1)\n",
    "targets = y.view(-1, 1)\n",
    "\n",
    "for epoch in range(1000):\n",
    "    y_hat = net(inputs)\n",
    "    loss = criterion(y_hat, targets)\n",
    "    optimiser.zero_grad()  # clear gradients accumulated by the previous step\n",
    "    loss.backward()\n",
    "    optimiser.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense grid of 1000 evenly spaced inputs covering -15 to +15\n",
    "xx = torch.linspace(start=-15, end=15, steps=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate net over denser input (try both eval() and train() modes)\n",
    "\n",
    "net.eval()  # call net.train() here instead to compare the two modes\n",
    "\n",
    "# One forward pass over the dense grid, without tracking gradients\n",
    "with torch.no_grad():\n",
    "    yy = net(xx.view(-1, 1)).squeeze()\n",
    "\n",
    "plt.plot(xx.numpy(), yy.numpy(), 'C1')\n",
    "plt.plot(x.numpy(), y.numpy(), 'oC0')\n",
    "plt.axis('equal')\n",
    "plt.ylim(-10, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiple (100) runs for denser input\n",
    "net.train()\n",
    "xx_col = xx.view(-1, 1)\n",
    "# Collect one prediction vector per forward pass, without tracking gradients\n",
    "with torch.no_grad():\n",
    "    y_hat = [net(xx_col).squeeze() for _ in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate mean and std over denser input\n",
    "# Stack into a new name: y_hat stays a list, so this cell can be re-run\n",
    "# without torch.stack() being handed an already-stacked tensor.\n",
    "y_hat_runs = torch.stack(y_hat)  # one row per forward pass\n",
    "mean = y_hat_runs.mean(0)  # statistics are taken across the runs (dim 0)\n",
    "std = y_hat_runs.std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the mean prediction together with the mean ± std band\n",
    "upper = (mean + std).numpy()\n",
    "lower = (mean - std).numpy()\n",
    "plt.plot(xx.numpy(), mean.numpy(), 'C1')\n",
    "plt.fill_between(xx.numpy(), upper, lower, color='C2')\n",
    "plt.plot(x.numpy(), y.numpy(), 'oC0')\n",
    "plt.axis('equal')\n",
    "plt.ylim(-10, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
